import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt

from model import ArtNet

from utils import *


# Dataloaders.
# `train_path`, `test_path` and `transformer` are presumably supplied by the
# `from utils import *` above -- TODO confirm.
# Training batches are reshuffled every epoch. Evaluation keeps a fixed order:
# accuracy is a sum over every test sample, so shuffling gains nothing, and a
# non-shuffling loader does not consume the global torch RNG between epochs,
# which keeps the training shuffle sequence reproducible.
train_loader = DataLoader(
    torchvision.datasets.ImageFolder(train_path, transform=transformer),
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    torchvision.datasets.ImageFolder(test_path, transform=transformer),
    batch_size=32,
    shuffle=False,
)

# Training the model


def _train_one_epoch(model, optimizer, loss, loader, device):
    """Run one optimisation pass of ``model`` over every batch in ``loader``.

    Args:
        model: network to train; switched to training mode here.
        optimizer: optimiser stepping ``model``'s parameters.
        loss: callable ``loss(outputs, labels)`` returning a scalar tensor.
        loader: iterable of ``(images, labels)`` batches exposing ``.dataset``.
        device: device the batches are moved to (the model's device).

    Returns:
        ``(mean_loss, accuracy)``, both averaged over ``len(loader.dataset)``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss_var = loss(outputs, labels)
        loss_var.backward()
        optimizer.step()

        # Weight the per-batch value by batch size so the epoch figure is a
        # per-sample average even when the last batch is short (assumes a
        # mean-reduced loss, as the default nn.CrossEntropyLoss() is).
        total_loss += loss_var.item() * images.size(0)
        _, predictions = torch.max(outputs, 1)
        correct += int((predictions == labels).sum())

    # Divide by the number of samples actually iterated, so the numerator and
    # denominator always describe the same set of images.
    num_samples = len(loader.dataset)
    return total_loss / num_samples, correct / num_samples


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader.dataset``.

    Runs in eval mode under ``torch.no_grad()``, so no autograd graph is built.
    """
    model.eval()
    correct = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        correct += int((predictions == labels).sum())
    return correct / len(loader.dataset)


def training(model, optimizer, loss, epochs):
    """Train ``model`` for ``epochs`` epochs and checkpoint the best test accuracy.

    Each epoch trains on the module-level ``train_loader``, evaluates on
    ``test_loader``, and reports through ``print_training_message``. The latter
    presumably comes from ``utils`` via the star import (TODO confirm). When
    test accuracy strictly improves, the model's ``state_dict`` is written to
    ``best_checkpoint.model``.

    Args:
        model: network to train.
        optimizer: optimiser over ``model``'s parameters.
        loss: criterion called as ``loss(outputs, labels)``.
        epochs: number of passes over the training set.

    Returns:
        The best test accuracy observed (0 if no epoch exceeded 0).
    """
    # Send batches to wherever the model's parameters live rather than keying
    # off torch.cuda.is_available(); this works for CPU and GPU placement alike.
    device = next(model.parameters()).device

    best_accuracy = 0
    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(
            model, optimizer, loss, train_loader, device
        )
        test_accuracy = _evaluate(model, test_loader, device)

        print_training_message(epoch, train_loss, train_accuracy, test_accuracy)

        # Save the best model
        if test_accuracy > best_accuracy:
            torch.save(model.state_dict(), "best_checkpoint.model")
            best_accuracy = test_accuracy
    return best_accuracy

def main():
    """Script entry point: build ArtNet with SGD and cross-entropy loss, then train it."""
    # Run configuration.
    n_epochs = 10

    net = ArtNet(num_classes).to(device)
    sgd = SGD(net.parameters(), lr=0.01, momentum=0.95)
    criterion = nn.CrossEntropyLoss()

    # Emits the banner followed by one empty line.
    print("Training Starts...", end="\n\n")
    training(net, sgd, criterion, n_epochs)


if __name__ == "__main__":
    main()
    
